Validate seqlens_k against cos_cache bounds in GroupQueryAttention to…#28277
apsonawane wants to merge 8 commits into main
… prevent rotary embedding OOB read
Pull request overview
This PR hardens the CPU GroupQueryAttention implementation by validating runtime seqlens_k values against the rotary embedding cache length when do_rotary is enabled, preventing out-of-bounds reads from the cos/sin caches that could leak heap data into inference outputs.
Changes:
- Add a per-batch bounds check in GroupQueryAttention::Compute() rejecting seqlens_k[b] >= cos_cache.shape[0] under do_rotary_.
- Return INVALID_ARGUMENT with a descriptive error when the bound is violated.
vraspar left a comment
Nice security fix -- the gap between CheckRotaryCaches (validates against total_sequence_length) and the actual position ID derivation (from seqlens_k) is a real and subtle bug. Good catch. A few non-blocking nits below.
Nit 1: CUDA EP coverage — This check only protects the CPU EP. The CUDA GQA kernel also uses do_rotary and derives position IDs similarly. Consider moving this validation into CheckInputs() in group_query_attention_helper.h (around line 282, after CheckRotaryCaches) so all EPs get the protection in one place. That said, seqlens_k may be on GPU in the CUDA path making a host-side loop infeasible, so the current placement is pragmatic.
Nit 2: Negative seqlens_k values — The condition seqlens_k_data[b] >= rotary_cache_max_seq will let negative values through, which would also produce OOB position IDs. Adding seqlens_k_data[b] < 0 to the check would catch that here with a clear error message. Likely already caught downstream by the RunRotaryEmbedding validation from #27597, so not urgent.
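A minimal sketch of what the check could look like with the lower bound added. ValidateSeqlensK is a hypothetical standalone helper, not an ONNX Runtime function, and it assumes seqlens_k holds int32 values and that the error is returned as a plain string rather than a Status:

```cpp
#include <cstddef>
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

// Hypothetical helper for illustration only; not part of ONNX Runtime.
// Returns an error message for the first out-of-range entry,
// or std::nullopt when every entry indexes inside the rotary cache.
std::optional<std::string> ValidateSeqlensK(const std::vector<int32_t>& seqlens_k,
                                            int64_t rotary_cache_max_seq) {
  for (std::size_t b = 0; b < seqlens_k.size(); ++b) {
    const int64_t value = seqlens_k[b];
    // Reject negative values as well as values at or past the cache length.
    if (value < 0 || value >= rotary_cache_max_seq) {
      return "seqlens_k[" + std::to_string(b) + "] = " + std::to_string(value) +
             " is outside [0, " + std::to_string(rotary_cache_max_seq) + ")";
    }
  }
  return std::nullopt;
}
```

With a single combined condition, a negative entry is reported with the same message format as an entry past the end of the cache.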
Nit 3: Multi-batch test — The two test cases cover the key scenarios well. A possible addition: a batch_size=2 test where seqlens_k = {3, 10} (one valid, one OOB) to verify the loop iterates correctly and the error message includes the right batch index (seqlens_k[1] = 10).
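The mixed-batch case could be sketched roughly as follows. This is not written against the real operator test harness: FirstOutOfBoundsBatch is a hypothetical stand-in for the per-batch loop, and the cache length of 8 is an assumed value chosen so that 3 is valid and 10 is not.

```cpp
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical stand-in for the per-batch loop: returns the index of the first
// batch entry whose seqlens_k is at or past the cache length, if there is one.
std::optional<std::size_t> FirstOutOfBoundsBatch(const std::vector<int32_t>& seqlens_k,
                                                 int64_t cache_len) {
  for (std::size_t b = 0; b < seqlens_k.size(); ++b) {
    if (seqlens_k[b] >= cache_len) return b;
  }
  return std::nullopt;
}

// Returns true when a batch with one valid and one out-of-bounds entry
// is rejected and the reported index is the second entry (index 1).
bool MixedBatchReportsSecondEntry() {
  const std::vector<int32_t> seqlens_k = {3, 10};
  const int64_t cache_len = 8;  // assumed cos_cache.shape[0] for this sketch
  const auto bad = FirstOutOfBoundsBatch(seqlens_k, cache_len);
  return bad.has_value() && *bad == 1;
}
```

A real test would additionally assert on the error text so that a regression in the reported batch index is caught.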
Overall this is clean, well-scoped, and well-documented. The error messages are descriptive and the tests directly exercise the vulnerability path.
Description
Validate seqlens_k values against cos_cache.shape[0] in GroupQueryAttention::Compute() when do_rotary is enabled, to prevent out-of-bounds reads in the rotary embedding lookup.
Root Cause
CheckRotaryCaches() validates cos_cache.shape[0] >= total_sequence_length, but runtime position IDs are derived from seqlens_k (a separate, per-batch input). An attacker can set total_sequence_length small enough to pass the guard while setting seqlens_k[b] far beyond cos_cache.shape[0], causing position_id = seqlens_k[b] to index out of bounds into the cos/sin cache. The resulting heap bytes are used as rotation values and propagate into the inference output.
Fix
Add an explicit bounds check in Compute() that rejects any seqlens_k[b] >= cos_cache.shape[0] before position IDs are computed. This is defense-in-depth alongside the existing RunRotaryEmbedding position_ids validation added in #27597.
Security
seqlens_k as an inference input. No model modification required.
Testing
Verified that crafted inputs with seqlens_k exceeding cos_cache dimensions now return INVALID_ARGUMENT instead of silently producing results containing leaked heap data.
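The gap described under Root Cause can be illustrated with a small model of the two guards. Both functions are hypothetical simplifications written for this sketch and do not reproduce the actual CheckRotaryCaches() or Compute() code:

```cpp
#include <cstdint>
#include <vector>

// Models the pre-existing guard as described above: it compares the cache
// length against total_sequence_length only.
bool PassesTotalLengthGuard(int64_t cache_len, int64_t total_sequence_length) {
  return cache_len >= total_sequence_length;
}

// Models the added guard: every per-batch seqlens_k value must index
// inside the cache.
bool PassesSeqlensGuard(int64_t cache_len, const std::vector<int32_t>& seqlens_k) {
  for (int32_t value : seqlens_k) {
    if (value >= cache_len) return false;
  }
  return true;
}
```

With an assumed cache length of 8, total_sequence_length = 4 passes the first guard while seqlens_k = {1000} fails the second. That combination is what the crafted inputs rely on.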